from torch.utils.tensorboard import SummaryWriter
import os, utils, glob, losses
import sys
from torch.utils.data import DataLoader
from data import datasets, trans
import numpy as np
import torch
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
from natsort import natsorted
from models.cycleMorph_model import cycleMorph
from models.cycleMorph_model import CONFIGS as CONFIGS


class AverageMeter:
    """Track the most recent value together with a count-weighted running mean.

    Attributes:
        val: last value passed to ``update``.
        sum: running sum of ``val * n`` over all updates.
        count: running sum of the ``n`` weights.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as if it had been observed ``n`` times.

        Raises ZeroDivisionError if the accumulated count is zero afterwards.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def MSE_torch(x, y):
    """Return the mean of the element-wise squared difference ``(x - y) ** 2``."""
    residual = x - y
    squared = residual ** 2
    return torch.mean(squared)

def get_default(args, var_name, default_settings, default_value, key_list):
    """Fill ``args.<var_name>`` from a nested settings mapping unless overridden.

    If ``var_name`` is present on ``args`` with a value different from
    ``default_value``, the user set it explicitly and it is left untouched.
    Otherwise ``default_settings[key_list[0]][key_list[1]]...`` is looked up
    and written into ``args.__dict__[var_name]``.

    Args:
        args: namespace-like object supporting ``in`` and ``__dict__``.
        var_name: attribute name to fill.
        default_settings: nested mapping holding the fallback value.
        default_value: the parser default used to detect a user override.
        key_list: sequence of keys walked into ``default_settings``.
    """
    if var_name not in args or not (args.__dict__[var_name] != default_value):
        node = default_settings
        for key in key_list:
            node = node[key]
        args.__dict__[var_name] = node

def prepare_input(resolution):
    """Return ``{'x': (x, y)}`` holding two float32 tensors of shape ``(1, *resolution)``.

    The tensors come from the legacy ``torch.FloatTensor(*sizes)``
    constructor, so their contents are uninitialised. Presumably this is
    only used as a shape-only dummy input (e.g. for profiling); confirm at
    call sites.
    """
    shape = (1, *resolution)
    moving = torch.FloatTensor(*shape)
    fixed = torch.FloatTensor(*shape)
    return {'x': (moving, fixed)}

def main():
    """Train CycleMorph on IXI and validate with Dice after every epoch.

    Per epoch:
      * one pass over the training pairs, calling
        ``model.optimize_parameters()`` per batch;
      * a no-grad pass over the validation pairs that warps the moving
        segmentation with the predicted flow (``visuals['flow_A']``) and
        scores it with ``utils.dice_val_VOI``;
      * a checkpoint of ``model.netG_A`` named ``dsc{:.3f}.pth.tar``, which
        ``save_checkpoint`` prunes;
      * TensorBoard scalars (losses, validation Dice) and slice figures of
        the last validation pair.
    """
    batch_size = 1
    atlas_dir = 'Path_to_IXI_data/atlas.pkl'
    train_dir = 'Path_to_IXI_data/Train/'
    val_dir = 'Path_to_IXI_data/Val/'
    save_dir = 'CycleMorph/'

    os.makedirs('experiments/' + save_dir, exist_ok=True)
    lr = 0.0001
    epoch_start = 0
    max_epoch = 500
    cont_training = False
    img_size = (160, 192, 224)
    reg_model = utils.register_model(img_size, 'nearest')
    reg_model.cuda()
    reg_model_bilin = utils.register_model(img_size, 'bilinear')
    reg_model_bilin.cuda()
    opt = CONFIGS['Cycle-Morph-v0']
    model = cycleMorph()
    model.initialize(opt)

    if cont_training:
        epoch_start = 335
        model_dir = 'experiments/' + save_dir
        # NOTE(review): updated_lr is computed but never handed to the model's
        # optimizers, so a resumed run uses whatever lr cycleMorph sets itself.
        updated_lr = round(lr * np.power(1 - (epoch_start) / max_epoch, 0.9), 8)
        # save_checkpoint keeps the top-scoring 'dsc{:.3f}.pth.tar' files and
        # natsorted orders them by ascending score, so the best is the LAST one.
        best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[-1])['state_dict']
        model.netG_A.load_state_dict(best_model)
    else:
        updated_lr = lr

    train_composed = transforms.Compose([trans.RandomFlip(0),
                                         trans.NumpyType((np.float32, np.float32)),
                                         ])

    val_composed = transforms.Compose([trans.Seg_norm(),  # rearrange segmentation label to 1 to 46
                                       trans.NumpyType((np.float32, np.int16)),
                                       ])

    train_set = datasets.IXIBrainDataset(glob.glob(train_dir + '*.pkl'), atlas_dir, transforms=train_composed)
    val_set = datasets.IXIBrainInferDataset(glob.glob(val_dir + '*.pkl'), atlas_dir, transforms=val_composed)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
    # Highest mean validation Dice so far; saved under the legacy 'best_mse' key
    # so existing checkpoint readers keep working.
    best_dsc = 0
    writer = SummaryWriter(log_dir=save_dir)
    # The grid volume is identical for every validation pair, so build it once
    # instead of re-allocating and re-uploading it per batch.
    grid_img = mk_grid_img(8, 1)
    for epoch in range(epoch_start, max_epoch):
        print('Training Starts')
        '''
        Training
        '''
        loss_all = AverageMeter()
        loss_net_a = AverageMeter()
        for idx, batch in enumerate(train_loader, start=1):
            batch = [t.cuda() for t in batch]
            x = batch[0]
            y = batch[1]
            model.set_input([x, y])
            loss_out, loss_reg, loss_net = model.optimize_parameters()
            loss_all.update(loss_out, y.numel())
            loss_net_a.update(loss_net, y.numel())
            print('Iter {} of {} loss {:.4f}, Reg: {:.6f}, loss net a: {:.6f}'.format(idx, len(train_loader), loss_out, loss_reg, loss_net))
        writer.add_scalar('Loss/net_a', loss_net_a.avg, epoch)
        writer.add_scalar('Loss/train', loss_all.avg, epoch)
        print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))
        '''
        Validation
        '''
        eval_dsc = AverageMeter()
        # (def_out, def_grid, x_seg, y_seg) of the final validation pair, used
        # for the figures; stays None when the validation loader is empty.
        last_val = None
        with torch.no_grad():
            for batch in val_loader:
                batch = [t.cuda() for t in batch]
                x = batch[0]
                y = batch[1]
                x_seg = batch[2]
                y_seg = batch[3]
                model.set_input([x, y])
                model.test()
                visuals = model.get_test_data()
                flow = visuals['flow_A'].cuda()
                def_out = reg_model([x_seg.float(), flow])
                def_grid = reg_model_bilin([grid_img.float(), flow])
                dsc = utils.dice_val_VOI(def_out.long(), y_seg.long())
                eval_dsc.update(dsc.item(), x.size(0))
                print(eval_dsc.avg)
                last_val = (def_out, def_grid, x_seg, y_seg)
        best_dsc = max(eval_dsc.avg, best_dsc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.netG_A.state_dict(),
            'best_mse': best_dsc,
        }, save_dir='experiments/' + save_dir, filename='dsc{:.3f}.pth.tar'.format(eval_dsc.avg))
        # Tag name is historical; the logged value is mean validation Dice.
        writer.add_scalar('MSE/validate', eval_dsc.avg, epoch)
        if last_val is not None:
            plt.switch_backend('agg')
            def_out, def_grid, x_seg, y_seg = last_val
            for tag, volume in (('Grid', def_grid),
                                ('input', x_seg),
                                ('ground truth', y_seg),
                                ('prediction', def_out)):
                fig = comput_fig(volume)
                writer.add_figure(tag, fig, epoch)
                plt.close(fig)
    writer.close()

def comput_fig(img):
    """Draw slices 48-63 (index 2) of ``img[0, 0]`` as a 4x4 grayscale grid.

    ``img`` is indexed as ``[0, 0, 48:64, :, :]``, so it must be a 5-D
    tensor; it is detached and copied to the CPU first. Returns the new
    matplotlib figure. The caller is responsible for closing it.
    """
    volume = img.detach().cpu().numpy()
    planes = volume[0, 0, 48:64, :, :]
    fig = plt.figure(figsize=(12, 12), dpi=180)
    for panel, plane in enumerate(planes, start=1):
        plt.subplot(4, 4, panel)
        plt.axis('off')
        plt.imshow(plane, cmap='gray')
    # Pack the panels edge to edge.
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig

def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
    """Apply polynomial ("poly") decay to every param group's learning rate.

    Sets ``lr = round(INIT_LR * (1 - epoch / MAX_EPOCHES) ** power, 8)`` on
    each entry of ``optimizer.param_groups``.
    """
    for group in optimizer.param_groups:
        decay = np.power(1 - (epoch) / MAX_EPOCHES, power)
        group['lr'] = round(INIT_LR * decay, 8)

def mk_grid_img(grid_step, line_thickness=1, img_size=(160, 192, 224), device='cuda'):
    """Build a binary grid volume for visualising a deformation field.

    Lines ``line_thickness`` voxels wide are drawn every ``grid_step`` voxels
    along axes 1 and 2 of the 3-D volume, spanning all of axis 0.

    Args:
        grid_step: spacing between consecutive grid lines, in voxels.
        line_thickness: width of each line in voxels. Lines touching the
            border are clipped rather than raising IndexError.
        img_size: 3-D volume shape; the default keeps the previous behaviour.
        device: torch device for the result. The default ``'cuda'`` matches
            the previous ``.cuda()`` call.

    Returns:
        float64 tensor of shape ``(1, 1, *img_size)`` on ``device``, 1 on grid
        lines and 0 elsewhere.
    """
    grid_img = np.zeros(img_size)
    for j in range(0, grid_img.shape[1], grid_step):
        # Slice assignment draws the full thickness and clips at the border;
        # the old single-index form drew one offset line and could overflow.
        grid_img[:, j:j + line_thickness, :] = 1
    for i in range(0, grid_img.shape[2], grid_step):
        grid_img[:, :, i:i + line_thickness] = 1
    grid_img = grid_img[None, None, ...]
    return torch.from_numpy(grid_img).to(device)

def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
    """Save ``state`` to ``save_dir/filename`` and prune the directory.

    ``save_dir`` is treated as a directory: it is joined with ``os.path.join``,
    so a trailing slash is optional, and it is created if missing. After
    saving, at most ``max_model_num`` regular files are kept. The files that
    sort first under ``natsorted`` are deleted. With main's fixed-width
    ``dsc{:.3f}.pth.tar`` names, those are the lowest-Dice checkpoints, which
    can include the file just written.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))
    # Only consider regular files, so a sub-directory can never be handed to os.remove.
    model_lists = natsorted(
        path for path in glob.glob(os.path.join(save_dir, '*')) if os.path.isfile(path)
    )
    surplus = max(len(model_lists) - max_model_num, 0)
    for stale in model_lists[:surplus]:
        os.remove(stale)

if __name__ == '__main__':
    # GPU configuration: list the visible CUDA devices, report the device
    # named by the config's first gpu_id, then launch training.
    GPU_iden = CONFIGS['Cycle-Morph-v0'].gpu_ids[0]
    GPU_num = torch.cuda.device_count()
    print(f'Number of GPU: {GPU_num}')
    for gpu_idx in range(GPU_num):
        print(f'     GPU #{gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}')
    # NOTE(review): torch.cuda.set_device(GPU_iden) is deliberately left off,
    # so "Currently using" names gpu_ids[0] rather than necessarily the active
    # device. Confirm whether cycleMorph.initialize selects the GPU itself.
    gpu_available = torch.cuda.is_available()
    print(f'Currently using: {torch.cuda.get_device_name(GPU_iden)}')
    print(f'If the GPU is available? {gpu_available}')
    main()
